{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Data Types\n",
    "\n",
    "Every tensor has a data type. In deep learning it is typically `float32`, but other types such as `int8` (e.g., for model quantization) are also used. The `tvm_vector_add` module we created in :numref:`ch_vector_add` only accepts `float32` tensors. Let's extend it to other data types in this section.\n",
    "\n",
    "\n",
    "## Specifying a Data Type\n",
    "\n",
    "To use a data type other than the default `float32`, we specify it explicitly when creating the placeholders. In the following code block, we generalize the vector addition expression defined in :numref:`ch_vector_add` to accept an argument `dtype` that specifies the data type. In particular, we pass `dtype` to `te.placeholder` when creating `A` and `B`. The result `C` then takes the same data type as `A` and `B`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import te\n",
    "import d2ltvm\n",
    "import numpy as np\n",
    "\n",
    "n = 100\n",
    "\n",
    "def tvm_vector_add(dtype):\n",
    "    A = te.placeholder((n,), dtype=dtype)\n",
    "    B = te.placeholder((n,), dtype=dtype)\n",
    "    C = te.compute(A.shape, lambda i: A[i] + B[i])\n",
    "    print('expression dtype:', A.dtype, B.dtype, C.dtype)\n",
    "    s = te.create_schedule(C.op)\n",
    "    return tvm.build(s, [A, B, C])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "Let's compile a module that accepts `int32` tensors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "origin_pos": 3,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expression dtype: int32 int32 int32\n"
     ]
    }
   ],
   "source": [
    "mod = tvm_vector_add('int32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "Then we define a function to verify the results for a particular data type. Note that we pass `get_abc` a constructor that converts each NumPy array to the requested data type with `astype` before wrapping it as a TVM array.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor dtype: int32 int32 int32\n"
     ]
    }
   ],
   "source": [
    "def test_mod(mod, dtype):\n",
    "    a, b, c = d2ltvm.get_abc(n, lambda x: tvm.nd.array(x.astype(dtype)))\n",
    "    print('tensor dtype:', a.dtype, b.dtype, c.dtype)\n",
    "    mod(a, b, c)\n",
    "    np.testing.assert_equal(c.asnumpy(), a.asnumpy() + b.asnumpy())\n",
    "\n",
    "test_mod(mod, 'int32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "Let's try other data types as well.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "origin_pos": 7,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expression dtype: float16 float16 float16\n",
      "tensor dtype: float16 float16 float16\n",
      "expression dtype: float64 float64 float64\n",
      "tensor dtype: float64 float64 float64\n",
      "expression dtype: int8 int8 int8\n",
      "tensor dtype: int8 int8 int8\n",
      "expression dtype: int16 int16 int16\n",
      "tensor dtype: int16 int16 int16\n",
      "expression dtype: int64 int64 int64\n",
      "tensor dtype: int64 int64 int64\n"
     ]
    }
   ],
   "source": [
    "for dtype in ['float16', 'float64', 'int8', 'int16', 'int64']:\n",
    "    mod = tvm_vector_add(dtype)\n",
    "    test_mod(mod, dtype)"
   ]
  },
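  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The check in `test_mod` compares the TVM result against the NumPy sum of the same two arrays. As a minimal NumPy-only sketch of why that reference has the expected type, the loop below adds two arrays of the same data type and confirms that the sum keeps that type. The arrays `x` and `y` are hypothetical stand-ins filled with ones, not the tensors returned by `get_abc`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "dtypes = ['float16', 'float64', 'int8', 'int16', 'int64']\n",
    "for dtype in dtypes:\n",
    "    # Hypothetical inputs: two small arrays of ones in the given dtype.\n",
    "    x = np.ones(4, dtype=dtype)\n",
    "    y = np.ones(4, dtype=dtype)\n",
    "    result = x + y\n",
    "    # Adding two arrays of the same dtype yields an array of that dtype.\n",
    "    assert result.dtype == np.dtype(dtype)\n",
    "    print(dtype, '->', result.dtype)\n",
    "```\n",
    "\n",
    "This only illustrates the NumPy side of the comparison; the TVM side is what `tvm_vector_add` prints as the expression data types.\n"
   ]
  },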
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "## Converting Element Data Types\n",
    "\n",
    "Besides constructing a tensor with a particular data type, we can also cast tensor elements to another data type during the computation. The following function is the same as `tvm_vector_add`, except that it casts the elements of `A` and `B` inside `te.compute` and leaves the placeholders at the default data type (`float32`). Because of the casting done by `astype`, the result `C` has the data type specified by `dtype`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [],
   "source": [
    "def tvm_vector_add_2(dtype):\n",
    "    A = te.placeholder((n,))\n",
    "    B = te.placeholder((n,))\n",
    "    C = te.compute(A.shape,\n",
    "                   lambda i: A[i].astype(dtype) + B[i].astype(dtype))\n",
    "    print('expression dtype:', A.dtype, B.dtype, C.dtype)\n",
    "    s = te.create_schedule(C.op)\n",
    "    return tvm.build(s, [A, B, C])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "Then we define a similar test function to verify the results. The inputs stay in `float32`, while the output tensor is created with the target data type.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "tvm"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expression dtype: float32 float32 int32\n",
      "tensor dtype: float32 float32 int32\n"
     ]
    }
   ],
   "source": [
    "def test_mod_2(mod, dtype):\n",
    "    a, b, c = d2ltvm.get_abc(n)\n",
    "    # By default, `get_abc` returns NumPy ndarrays in float32.\n",
    "    a_tvm, b_tvm = tvm.nd.array(a), tvm.nd.array(b)\n",
    "    c_tvm = tvm.nd.array(c.astype(dtype))\n",
    "    print('tensor dtype:', a_tvm.dtype, b_tvm.dtype, c_tvm.dtype)\n",
    "    mod(a_tvm, b_tvm, c_tvm)\n",
    "    np.testing.assert_equal(c_tvm.asnumpy(), a.astype(dtype) + b.astype(dtype))\n",
    "\n",
    "mod = tvm_vector_add_2('int32')\n",
    "test_mod_2(mod, 'int32')"
   ]
  },
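  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the reference result in `test_mod_2` casts `a` and `b` before adding them, which matches the order of operations in `tvm_vector_add_2`. The order matters when casting from floating point to integer. The following NumPy-only sketch uses hypothetical sample values (not the ones returned by `get_abc`) to show that casting and then adding can differ from adding and then casting.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical sample values, chosen so that the two orderings disagree.\n",
    "a = np.array([0.6, 1.7], dtype='float32')\n",
    "b = np.array([0.6, 0.5], dtype='float32')\n",
    "\n",
    "# Cast each operand to int32 first, then add (the order used above).\n",
    "cast_then_add = a.astype('int32') + b.astype('int32')\n",
    "# Add in float32 first, then cast the sum to int32.\n",
    "add_then_cast = (a + b).astype('int32')\n",
    "\n",
    "print(cast_then_add, add_then_cast)\n",
    "```\n",
    "\n",
    "NumPy truncates toward zero when casting floats to integers, so the first ordering gives `[0 1]` and the second gives `[1 2]`. Using the wrong ordering in the reference would make `assert_equal` fail even when the compiled module is correct.\n"
   ]
  },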
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "## Summary\n",
    "\n",
    "- We can specify the data type with the `dtype` argument when creating TVM placeholders.\n",
    "- The data type of a tensor element can be cast with `astype` inside `te.compute`.\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}